Exploring the Google QuickDraw Dataset with SketchRNN (Part 3)
This is the third part in a series of notes on my exploration of the recently released Google QuickDraw dataset, using the concurrently released SketchRNN model.
Again, I've discarded the markdown cells or codeblocks that were intended to explain or demonstrate something, retaining only the code I need to run the experiments in this notebook. Everything until the section Principal Component Analysis in the Latent Space was copied directly from previous notebooks. Here are links to the first and second notes.
The QuickDraw dataset is curated from the millions of drawings contributed by over 15 million people around the world who participated in the "Quick, Draw!" A.I. Experiment, in which they were given the challenge of drawing objects belonging to a particular class (such as "cat") in under 20 seconds.
SketchRNN is a very impressive generative model that was trained to produce vector drawings using this dataset. It was of particular interest to me because it cleverly combines many of the latest tools and techniques recently developed in machine learning, such as Variational Autoencoders, HyperLSTMs (a HyperNetwork for LSTM), Autoregressive models, Layer Normalization, Recurrent Dropout, the Adam optimizer, and others.
# Notebook magics: render matplotlib figures inline as SVG, and
# auto-reload edited modules without restarting the kernel.
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import tensorflow as tf
from matplotlib.animation import FuncAnimation
from matplotlib.path import Path
from matplotlib import rc
from sklearn.decomposition import PCA
from itertools import product
from six.moves import map, zip
# Training utilities and pre-trained model helpers from magenta's
# sketch-rnn implementation.
from magenta.models.sketch_rnn.sketch_rnn_train import \
    (load_env,
     load_checkpoint,
     reset_graph,
     download_pretrained_models,
     PRETRAINED_MODELS_URL)
from magenta.models.sketch_rnn.model import Model, sample
from magenta.models.sketch_rnn.utils import (lerp,
                                             slerp,
                                             get_bounds,
                                             to_big_strokes,
                                             to_normal_strokes)
# For inline display of animation
# equivalent to rcParams['animation.html'] = 'html5'
rc('animation', html='html5')
# set numpy output to something sensible
np.set_printoptions(precision=8,
                    edgeitems=6,
                    linewidth=200,
                    suppress=True)
tf.logging.info("TensorFlow Version: {}".format(tf.__version__))
Getting the Pre-Trained Models and Data¶
# Remote location of the aaron_sheep dataset and the local directory
# where the pre-trained SketchRNN checkpoints will be stored.
# Fix: use HTTPS — GitHub no longer serves content over plain HTTP
# directly; the old URL relied on a redirect that some fetchers refuse.
DATA_DIR = ('https://github.com/hardmaru/sketch-rnn-datasets/'
            'raw/master/aaron_sheep/')
MODELS_ROOT_DIR = '/tmp/sketch_rnn/models'
DATA_DIR
PRETRAINED_MODELS_URL
# Download the pre-trained model checkpoints into MODELS_ROOT_DIR.
download_pretrained_models(
    models_root_dir=MODELS_ROOT_DIR,
    pretrained_models_url=PRETRAINED_MODELS_URL)
We look at the layer normalized model trained on the aaron_sheep dataset for now.
# Use the layer-normalized model trained on the aaron_sheep dataset.
MODEL_DIR = MODELS_ROOT_DIR + '/aaron_sheep/layer_norm'
# load_env returns the three dataset splits plus the hyperparameter
# configurations for the training, evaluation and sampling graphs.
(train_set,
 valid_set,
 test_set,
 hps_model,
 eval_hps_model,
 sample_hps_model) = load_env(DATA_DIR, MODEL_DIR)
class SketchPath(Path):
    """A matplotlib ``Path`` built from a stroke-3 sketch sequence.

    Each row of ``data`` holds ``(dx, dy, pen_state)``; the offsets are
    accumulated into absolute vertices and the pen states are turned
    into MOVETO/LINETO drawing codes.
    """

    def __init__(self, data, factor=.2, *args, **kwargs):
        # Accumulate per-step offsets into absolute coordinates,
        # rescaled by `factor`.
        vertices = np.cumsum(data[:, :-1], axis=0) / factor
        pen_states = data[:, -1].astype(int)
        # A pen lift ends the *previous* segment, so shift the codes one
        # step forward; the first vertex always starts a new sub-path.
        codes = np.roll(self.to_code(pen_states), shift=1)
        codes[0] = Path.MOVETO
        super(SketchPath, self).__init__(vertices, codes, *args, **kwargs)

    @staticmethod
    def to_code(cmd):
        # cmd == 0 -> LINETO; cmd == 1 -> MOVETO (which is LINETO - 1)
        return Path.LINETO - cmd
def draw(sketch_data, factor=.2, pad=(10, 10), ax=None):
    """Draw a stroke-3 sketch as a matplotlib PathPatch.

    Args:
        sketch_data: array of (dx, dy, pen_state) rows.
        factor: scale divisor applied to the stroke offsets.
        pad: (x, y) padding in data units, split evenly on both sides.
        ax: target axes; defaults to the current axes.
    """
    if ax is None:
        ax = plt.gca()
    x_pad, y_pad = pad
    x_pad //= 2
    y_pad //= 2
    x_min, x_max, y_min, y_max = get_bounds(data=sketch_data,
                                            factor=factor)
    ax.set_xlim(x_min-x_pad, x_max+x_pad)
    # y-limits are inverted so the sketch appears right side up.
    ax.set_ylim(y_max+y_pad, y_min-y_pad)
    # Bug fix: scale the path by the same `factor` used for the bounds.
    # Previously SketchPath always used its default of .2, so any other
    # `factor` produced limits that did not match the drawn path.
    sketch = SketchPath(sketch_data, factor=factor)
    patch = patches.PathPatch(sketch, facecolor='none')
    ax.add_patch(patch)
Load pre-trained models¶
# construct the sketch-rnn model here:
reset_graph()
# Three graphs sharing one set of weights: training, evaluation
# (the encoder used below) and sampling (the decoder used below).
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# loads the weights from checkpoint into our model
load_checkpoint(sess=sess, checkpoint_path=MODEL_DIR)
def encode(input_strokes):
    """Encode a stroke-3 sketch into its latent vector via the eval model."""
    # Pad to the model's big-stroke (stroke-5) format and prepend the
    # start-of-sequence token [0, 0, 1, 0, 0].
    strokes = to_big_strokes(input_strokes).tolist()
    strokes.insert(0, [0, 0, 1, 0, 0])
    feed = {
        eval_model.input_data: [strokes],
        eval_model.sequence_lengths: [len(input_strokes)],
    }
    # batch_z has one row per batch element; we encode a batch of one.
    return sess.run(eval_model.batch_z, feed_dict=feed)[0]
def decode(z_input=None, temperature=.1, factor=.2):
    """Sample a sketch from the decoder, optionally conditioned on the
    latent vector `z_input`; returns the sketch in stroke-3 format.

    NOTE(review): `factor` is accepted for interface compatibility but
    is not used anywhere in this function.
    """
    z = None if z_input is None else [z_input]
    sample_strokes, _ = sample(
        sess,
        sample_model,
        seq_len=eval_model.hps.max_seq_len,
        temperature=temperature, z=z)
    return to_normal_strokes(sample_strokes)
Principal Component Analysis in the Latent Space¶
What do you call a baby eigensheep? A lamb, duh.
We encode all of the sketches in the test set into their learned 128-dimensional latent space representations.
# Encode every test-set sketch into the learned latent space;
# Z has one row (one latent code) per sketch.
Z = np.vstack(map(encode, test_set.strokes))
Z.shape
Then, we find the top two principal axes that represent the direction of maximum variance in the data encoded in the latent space.
# Fit a 2-component PCA on the latent codes.
pca = PCA(n_components=2)
pca.fit(Z)
The two components each account for about 2% of the variance
# Fraction of latent-space variance captured by each principal component.
pca.explained_variance_ratio_
Let's project the data from the 128-dimensional latent space to the lower 2-dimensional space spanned by the first 2 principal components
# Project the latent codes onto the first two principal components
# and scatter-plot the 2-D embedding.
Z_pca = pca.transform(Z)
Z_pca.shape
fig, ax = plt.subplots()
ax.scatter(*Z_pca.T)
ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')
plt.show()
We'd like to visualize the original sketches at their corresponding points on this plot. Each point corresponds to the latent code of a sketch, reduced to 2 dimensions. However, the plot is slightly too dense to fit sufficiently large sketches without overlapping them. Therefore, we restrict our attention to a smaller region that encompasses 80% of the data points, discarding those outside of the 5th and 95th percentiles in both axes. The blue shaded rectangle highlights our region of interest.
# The 5th/95th percentiles along each principal axis bound the region
# of interest; the shaded rectangle marks it on the scatter plot.
((pc1_min, pc2_min),
 (pc1_max, pc2_max)) = np.percentile(Z_pca, q=[5, 95], axis=0)
roi_rect = patches.Rectangle(xy=(pc1_min, pc2_min),
                             width=pc1_max-pc1_min,
                             height=pc2_max-pc2_min, alpha=.4)
fig, ax = plt.subplots()
ax.scatter(*Z_pca.T)
ax.add_patch(roi_rect)
ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')
plt.show()
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_xlim(pc1_min, pc1_max)
ax.set_ylim(pc2_min, pc2_max)
# Draw each test sketch at its projected location in PC space.
for i, sketch in enumerate(test_set.strokes):
    sketch_path = SketchPath(sketch, factor=7e+1)
    # Flip y (sketch coordinates grow downward) and translate the
    # sketch to its 2-D PCA embedding.
    sketch_path.vertices[::,1] *= -1
    sketch_path.vertices += Z_pca[i]
    patch = patches.PathPatch(sketch_path, facecolor='none')
    ax.add_patch(patch)
ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')
plt.savefig("../../files/sketchrnn/aaron_sheep_pca.svg",
            format="svg")
Remark: There is a far more clever way to produce this plot that involves the use of matplotlib Transformations and the Collections API, by instantiating PathCollection with the keyword argument offsets defined as the array of projected locations (in this instance Z_pca). However, I was unable to get this to work (maybe I was working in the wrong coordinate system?). In either case, if you find a slicker way to produce this plot, please let me know. I'd love to learn about your approach!
Linear Interpolation in PCA¶
Linear interpolation in the subspace spanned by the first 2 principal axes.
# Build a 10x10 grid of points spanning the region of interest in PC
# space (lerp linearly interpolates between the two endpoints).
pc1 = lerp(pc1_min, pc1_max, np.linspace(0, 1, 10))
pc2 = lerp(pc2_min, pc2_max, np.linspace(0, 1, 10))
pc1_mesh, pc2_mesh = np.meshgrid(pc1, pc2)
# Show the grid points (orange) on top of the projected data (blue).
fig, ax = plt.subplots()
ax.set_xlim(pc1_min-1, pc1_max+1)
ax.set_ylim(pc2_min-1, pc2_max+1)
ax.scatter(*Z_pca.T)
ax.scatter(pc1_mesh, pc2_mesh)
ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')
plt.show()
np.dstack((pc1_mesh, pc2_mesh)).shape
# Lift each 2-D grid point back into the full latent space by applying
# the PCA inverse transform along the last axis of the (10, 10, 2) stack.
z_grid = np.apply_along_axis(pca.inverse_transform,
                             arr=np.dstack((pc1_mesh, pc2_mesh)),
                             axis=2)
z_grid.shape
def _show_decoded_grid(z_grid, temperature):
    """Decode and draw every latent vector in the 10x10 `z_grid` at the
    given sampling temperature, one subplot per grid point."""
    fig, ax_arr = plt.subplots(nrows=10,
                               ncols=10,
                               figsize=(8, 8),
                               subplot_kw=dict(xticks=[],
                                               yticks=[],
                                               frame_on=False))
    fig.tight_layout()
    for i, ax_row in enumerate(ax_arr):
        for j, ax in enumerate(ax_row):
            draw(decode(z_grid[i, j], temperature=temperature), ax=ax)
            ax.axis('off')
    plt.show()

# Render the same grid at a low and a higher sampling temperature.
# Previously these were two copy-pasted cells differing only in the
# temperature value.
_show_decoded_grid(z_grid, temperature=.1)
_show_decoded_grid(z_grid, temperature=.6)
Eigensheep Decomposition¶
# The first two principal directions ("eigensheep") in latent space,
# plus their explained-variance statistics.
z_pc1, z_pc2 = pca.components_
pca.explained_variance_ratio_
pca.explained_variance_
def _show_eigensheep(z_pc):
    """Decode the principal direction `z_pc` at 10 temperatures
    (columns, tau = 0.1 .. 1.0) with 5 independent samples per
    temperature (rows); label the bottom row with tau."""
    fig, ax_arr = plt.subplots(nrows=5,
                               ncols=10,
                               figsize=(8, 4),
                               subplot_kw=dict(xticks=[],
                                               yticks=[],
                                               frame_on=False))
    fig.tight_layout()
    for row_num, ax_row in enumerate(ax_arr):
        for col_num, ax in enumerate(ax_row):
            t = (col_num + 1) / 10.
            draw(decode(z_pc, temperature=t), ax=ax)
            if row_num + 1 == len(ax_arr):
                ax.set_xlabel(r'$\tau={}$'.format(t))
    plt.show()

# Decode each of the two principal directions. Previously these were
# two copy-pasted cells differing only in the vector decoded.
_show_eigensheep(z_pc1)
_show_eigensheep(z_pc2)
t-SNE Visualization¶
from sklearn.manifold import TSNE
# Non-linear 2-D embedding of the latent codes; n_iter is raised from
# the default (1000) for better convergence.
tsne = TSNE(n_components=2, n_iter=5000)
Z_tsne = tsne.fit_transform(Z)
fig, ax = plt.subplots()
ax.scatter(*Z_tsne.T)
ax.set_xlabel('$c_1$')
ax.set_ylabel('$c_2$')
plt.show()
# Final KL divergence of the t-SNE optimization (lower is better).
tsne.kl_divergence_
# Same region-of-interest construction as for PCA: clip to the
# 5th/95th percentiles along each embedded axis.
((c1_min, c2_min),
 (c1_max, c2_max)) = np.percentile(Z_tsne, q=[5, 95], axis=0)
roi_rect = patches.Rectangle(xy=(c1_min, c2_min),
                             width=c1_max-c1_min,
                             height=c2_max-c2_min, alpha=.4)
fig, ax = plt.subplots()
ax.scatter(*Z_tsne.T)
ax.add_patch(roi_rect)
ax.set_xlabel('$c_1$')
ax.set_ylabel('$c_2$')
plt.show()
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_xlim(c1_min, c1_max)
ax.set_ylim(c2_min, c2_max)
# Draw each test sketch at its t-SNE embedding.
for i, sketch in enumerate(test_set.strokes):
    sketch_path = SketchPath(sketch, factor=2.)
    # Flip y and translate the sketch to its embedded location.
    sketch_path.vertices[::,1] *= -1
    sketch_path.vertices += Z_tsne[i]
    patch = patches.PathPatch(sketch_path, facecolor='none')
    ax.add_patch(patch)
ax.axis('off')
plt.savefig("../../files/tsne.svg", format="svg")

